#ifndef __SOCKET_H__
#define __SOCKET_H__

#include"address.h"
#include"noncopyable.h"
#include "iomanager.h"
#include "fd_manager.h"
#include "log.h"
#include "macro.h"
#include "hook.h"

#include<netinet/tcp.h>
#include<sys/socket.h>
#include<sys/types.h>
#include<memory>

#include <openssl/err.h>
#include <openssl/ssl.h>

namespace sylar{

/**
 * @brief Wrapper around a socket file descriptor.
 *
 * Instances are created through the static factories (or the constructor)
 * and shared through Socket::ptr. Copying is disabled via Noncopyable.
 * SSLSocket derives from this class and overrides the virtual operations.
 */
class Socket : public std::enable_shared_from_this<Socket>, Noncopyable{

public:
    using ptr = std::shared_ptr<Socket>;
    using weak_ptr = std::weak_ptr<Socket>;

    /// Socket type; values are the corresponding SOCK_* constants.
    enum Type {
        TCP = SOCK_STREAM,
        UDP = SOCK_DGRAM
    };

    /// Address family; values are the corresponding AF_* constants.
    enum Family {
        IPv4 = AF_INET,
        IPv6 = AF_INET6,
        UNIX = AF_UNIX,
    };

    // Factory helpers. The Address-based overloads presumably take the family
    // from `address` — TODO confirm in the implementation file.
    static Socket::ptr CreateTCP(sylar::Address::ptr address);
    static Socket::ptr CreateUDP(sylar::Address::ptr address);
    static Socket::ptr CreateTCPSocket();
    static Socket::ptr CreateUDPSocket();
    static Socket::ptr CreateTCPSocket6();
    static Socket::ptr CreateUDPSocket6();
    static Socket::ptr CreateUnixTCPSocket();
    static Socket::ptr CreateUnixUDPSocket();

    /**
     * @param family   address family (see Family)
     * @param type     socket type (see Type)
     * @param protocol protocol number; 0 selects the default for family/type
     */
    Socket(int family, int type, int protocol = 0);

    // Virtual: SSLSocket derives from Socket and its objects are handled
    // through Socket::ptr (e.g. SSLSocket::accept returns Socket::ptr), so
    // destruction through a base pointer must reach the derived destructor.
    virtual ~Socket();

    // Send/receive timeouts. Unit is presumably milliseconds — TODO confirm.
    int64_t getSendTimeout();
    void setSendTimeout(int64_t v);

    int64_t getRecvTimeout();
    void setRecvTimeout(int64_t v);

    /**
     * @brief Read a socket option (presumably wraps getsockopt).
     * @param[in]     level  protocol level, e.g. SOL_SOCKET
     * @param[in]     option option name
     * @param[out]    result buffer receiving the option value
     * @param[in,out] len    size of @p result
     * @return true on success
     */
    bool getOption(int level, int option, void* result, socklen_t* len);

    /// Typed convenience overload: reads an option straight into @p result.
    template<typename T>
    bool getOption(int level, int option, T& result){
        // Must be socklen_t: its address is passed to the socklen_t* overload
        // above, and a size_t* would not convert (compile error).
        socklen_t length = static_cast<socklen_t>(sizeof(T));
        return getOption(level, option, &result, &length);
    }

    /**
     * @brief Set a socket option (presumably wraps setsockopt).
     * @param level  protocol level, e.g. SOL_SOCKET
     * @param option option name
     * @param result pointer to the new option value
     * @param length size of the value pointed to by @p result
     * @return true on success
     */
    bool setOption(int level, int option, const void* result, socklen_t length);

    /// Typed convenience overload: sets an option from @p value.
    template<class T>
    bool setOption(int level,int option,const T& value){
        return setOption(level,option,&value,sizeof(T));
    }

    /// Bind the socket to @p addr.
    virtual bool bind(const Address::ptr addr);

    /// Accept a pending connection and return it as a new Socket
    /// (presumably null on failure — TODO confirm).
    virtual Socket::ptr accept();

    /**
     * @brief Start listening for connections.
     * @param backlog length of the pending-connection queue; defaults to the
     *                system-defined SOMAXCONN
     * @return true on success
     */
    virtual bool listen(int backlog = SOMAXCONN);

    /**
     * @brief Connect to a remote address.
     * @param addr       remote address
     * @param timeout_ms connect timeout in milliseconds; the default -1
     *                   (converted to UINT64_MAX) means no timeout
     * @return true on success
     */
    virtual bool connect(const Address::ptr addr, uint64_t timeout_ms = -1);

    /// Reconnect, presumably to the previously used remote address — TODO confirm.
    virtual bool reconnect(uint64_t timeout_ms = -1);

    virtual bool close();

    // Data transfer. Single-buffer and scatter/gather (iovec) variants;
    // for the iovec overloads `length` is presumably the number of iovecs.
    virtual int send(const void* buffer, size_t length, int flags = 0);
    virtual int send(const iovec* buffers, size_t length, int flags = 0);

    virtual int sendTo(const void* buffer, size_t length, const Address::ptr to, int flags = 0);
    virtual int sendTo(const iovec* buffers, size_t length, const Address::ptr to, int flags = 0);

    virtual int recv(void* buffer, size_t length, int flags = 0);
    virtual int recv(iovec* buffers, size_t length, int flags = 0);

    virtual int recvFrom(void* buffer, size_t length, Address::ptr from, int flags = 0);
    virtual int recvFrom(iovec* buffers, size_t length, Address::ptr from, int flags = 0);

    /// Address of the peer this socket is connected to.
    Address::ptr getRemoteAddress();
    /// Local address (and port) this socket is bound to.
    Address::ptr getLocalAddress();

    /// Address family, e.g. AF_INET / AF_INET6.
    int getFamily() const { return m_family;}
    /// Socket type, e.g. SOCK_STREAM / SOCK_DGRAM.
    int getType() const { return m_type;}
    /// Protocol number (some socket types have a specific protocol).
    int getProtocol() const { return m_protocol;}
    /// Whether the socket is connected.
    bool isConnected() const { return m_isConnected;}
    /// Whether the socket holds a valid descriptor.
    bool isValid() const;
    /// Pending socket error (presumably SO_ERROR — TODO confirm).
    int getError();

    /// Write a human-readable description of the socket to @p os.
    virtual std::ostream& dump(std::ostream& os) const;
    virtual std::string toString() const;

    /// Underlying socket file descriptor.
    int getSocket() const { return m_sock;}

    // Cancel pending events on this socket so that whoever is waiting on
    // them is forcibly woken up.
    bool cancelRead();
    bool cancelWrite();
    bool cancelAccept();
    bool cancelAll();

protected:

    // Presumably applies default options to m_sock — TODO confirm.
    void initSock();
    // Presumably creates a fresh descriptor for m_family/m_type/m_protocol — TODO confirm.
    void newSock();
    /// Adopt an existing descriptor @p sock into this Socket.
    virtual bool init(int sock);

protected:
    int m_sock;
    int m_family;
    int m_type;
    int m_protocol;
    bool m_isConnected;

    Address::ptr m_localAddress;
    Address::ptr m_remoteAddress;
};

class SSLSocket : public Socket {
public:
    typedef std::shared_ptr<SSLSocket> ptr;

    static SSLSocket::ptr CreateTCP(sylar::Address::ptr address);
    static SSLSocket::ptr CreateTCPSocket();
    static SSLSocket::ptr CreateTCPSocket6();

    SSLSocket(int family, int type, int protocol = 0);
    virtual Socket::ptr accept() override;
    virtual bool bind(const Address::ptr addr) override;
    virtual bool connect(const Address::ptr addr, uint64_t timeout_ms = -1) override;
    virtual bool listen(int backlog = SOMAXCONN) override;
    virtual bool close() override;
    virtual int send(const void* buffer, size_t length, int flags = 0) override;
    virtual int send(const iovec* buffers, size_t length, int flags = 0) override;
    virtual int sendTo(const void* buffer, size_t length, const Address::ptr to, int flags = 0) override;
    virtual int sendTo(const iovec* buffers, size_t length, const Address::ptr to, int flags = 0) override;
    virtual int recv(void* buffer, size_t length, int flags = 0) override;
    virtual int recv(iovec* buffers, size_t length, int flags = 0) override;
    virtual int recvFrom(void* buffer, size_t length, Address::ptr from, int flags = 0) override;
    virtual int recvFrom(iovec* buffers, size_t length, Address::ptr from, int flags = 0) override;

    bool loadCertificates(const std::string& cert_file, const std::string& key_file);
    virtual std::ostream& dump(std::ostream& os) const override;
protected:
    virtual bool init(int sock) override;
private:
    std::shared_ptr<SSL_CTX> m_ctx;
    std::shared_ptr<SSL> m_ssl;
};


/// Stream a textual description of a socket.
/// @param[in,out] os   destination stream
/// @param[in]     sock socket to describe
/// @return @p os, so calls can be chained
std::ostream& operator<<(std::ostream& os, const Socket& sock);

}



#endif